19. Saliency Maps

  • motivating application

  • saliency maps

  • captum package

Learning outcomes

  • Analyze large-scale foundation models using methods like sparse autoencoders and describe their relevance to problems of model control and AI-driven design.
  • Within a specific application context, evaluate the trade-offs associated with competing interpretable machine learning techniques.

Readings

  • Adebayo, J., Gilmer, J., Muelly, M., Goodfellow, I., Hardt, M., & Kim, B. (2018). Sanity Checks for Saliency Maps. doi:10.48550/ARXIV.1810.03292
  • Captum - Model Interpretability for Pytorch. Getting started with Captum - Titanic Data Analysis. (n.d.). https://captum.ai/tutorials/Titanic_Basic_Interpret

Motivating Case Study

India lacks centralized renewable energy databases.

Questions:

  • Solar plant counts over time?
  • Geographic distribution?

Needed for energy transition planning (Ortiz2022; solar_data).

Scientific Goal

Question. Can we locate solar farms in satellite imagery?

Data. Sentinel-2 imagery (10–60 m/px, 12 spectral bands) + OpenStreetMap labels. 1,363 labeled sites.

Constraint. Must isolate pixels responsible for predictions.

Cautionary Example

Models can pick up on spurious correlations, such as the presence of snow in the background rather than the object itself. This leads to generalization failure when the correlation does not hold at deployment time.

Statistical Formulation

Labels. \(y \in \{0, 1\}^{256 \times 256}\) (pixel labels)

Predictors. \(x \in \mathbb{R}^{10 \times 256 \times 256}\) (spectral patch)

Model. \(f(x; \theta) \to [0, 1]^{256 \times 256}\) (probability mask)

Goal. Find attributions \(\varphi \in \mathbb{R}^{256 \times 256}\) explaining \(f(x)\).

Saliency Maps

Perturbation Methods

Perturb input \(x\) to measure prediction changes.

Occlusion. Mask patches, observe \(\Delta f(x)\)

Feature Ablation. Replace features with baseline, measure marginal effects
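A minimal occlusion sketch, assuming a toy scoring function (here the mean pixel value standing in for a trained model): slide a baseline patch over the input and attribute the resulting score drop to the occluded pixels.

```python
import numpy as np

# Toy "model": score is the mean brightness of the image.
def f(x):
    return x.mean()

rng = np.random.default_rng(0)
x = rng.random((8, 8))

def occlusion_map(f, x, patch_size=2, baseline=0.0):
    """Slide a patch of baseline values over x; record the score change."""
    base_score = f(x)
    attr = np.zeros_like(x)
    for i in range(0, x.shape[0], patch_size):
        for j in range(0, x.shape[1], patch_size):
            x_occ = x.copy()
            x_occ[i:i + patch_size, j:j + patch_size] = baseline
            # Attribute the score drop to every pixel in the patch.
            attr[i:i + patch_size, j:j + patch_size] = base_score - f(x_occ)
    return attr

attr = occlusion_map(f, x)
```

Feature ablation works the same way, except the "patch" is any user-defined feature group and the baseline may be a per-feature reference value rather than zero.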

Gradient Saliency

\[E_{\text{grad}}(x) = \frac{\partial S_c(x)}{\partial x}\]

Identifies pixels where small changes maximally affect class \(c\) score.

Refinement: \(x \odot \frac{\partial S_c}{\partial x}\) accounts for feature scale.
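A small numerical illustration under an assumed toy model (a logistic score with a hand-written gradient, standing in for autograd): gradient saliency ranks features by sensitivity alone, while gradient × input also accounts for the feature's magnitude.

```python
import numpy as np

# Toy differentiable model: S(x) = sigmoid(w . x).
w = np.array([2.0, -1.0, 0.5])
x = np.array([1.0, 1.0, 4.0])

def score(x):
    return 1.0 / (1.0 + np.exp(-w @ x))

# Analytic gradient of the sigmoid score w.r.t. the input;
# in a real network torch.autograd computes this.
s = score(x)
grad = s * (1 - s) * w

grad_saliency = np.abs(grad)   # plain gradient saliency
input_x_grad = x * grad        # gradient * input refinement
```

Note that the third feature has a small weight (0.5) but a large value (4.0): plain gradient saliency ranks it low, while gradient × input ranks it equal to the first feature.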

Computation

  • Mechanism. Efficiently computed via a single backpropagation pass using torch.autograd.
  • Aggregation. For multi-channel inputs (e.g., RGB or 12-band spectral), gradients are typically reduced to a single map using: \[M(x) = \max_j | \nabla_{x_j} f(x) |\] where \(j\) is the channel index.
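The aggregation step above can be sketched directly in numpy, assuming a gradient tensor of shape (channels, height, width):

```python
import numpy as np

# Hypothetical gradients for a 3-channel, 4x4 input: shape (C, H, W).
rng = np.random.default_rng(1)
grads = rng.standard_normal((3, 4, 4))

# Collapse channels into one saliency map by taking the maximum
# absolute gradient across channels at each pixel.
saliency = np.abs(grads).max(axis=0)
```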

Gradient Limitations

Saturation. Flat regions (ReLU) yield zero gradients despite importance.

Noise. High-frequency artifacts from shattered gradients.

Instability. Gradients can vary sharply even in regions where the prediction itself is stable.

Integrated Gradients

Satisfies completeness: \(\sum_i \text{IG}_i(x) = f(x) - f(x')\)

\[\text{IG}_{i}(x) = \left(x_{i} - x_{i}'\right) \int_{\alpha=0}^{1} \frac{\partial f\left(x' + \alpha\left(x - x'\right)\right)}{\partial x_{i}} d\alpha\]

\(x'\): baseline (e.g., black image)

Aggregates gradients along path \(x' \to x\).

Approximation

Riemann sum with \(m \in [50, 200]\) steps:

\[\text{IG}_i(x) \approx (x_i - x'_i) \sum_{k=1}^m \frac{\partial f(x' + \frac{k}{m}(x - x'))}{\partial x_i} \cdot \frac{1}{m}\]

Cost: \(\mathcal{O}(m)\) backpropagations.
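The Riemann-sum approximation can be checked on a function with a closed-form gradient. A sketch assuming \(f(x) = \sum_i x_i^2\) and a zero baseline, where the exact attribution is \(x_i^2\):

```python
import numpy as np

def f(x):
    return np.sum(x ** 2)

def grad_f(x):
    # Closed-form gradient; a network would use one backprop per step.
    return 2.0 * x

x = np.array([1.0, -2.0, 3.0])
baseline = np.zeros_like(x)
m = 200

# Riemann sum over the straight-line path from baseline to x.
total = np.zeros_like(x)
for k in range(1, m + 1):
    point = baseline + (k / m) * (x - baseline)
    total += grad_f(point)
ig = (x - baseline) * total / m
```

With \(m = 200\) steps the approximation is within about \(1/m\) of the exact values, and the completeness property \(\sum_i \text{IG}_i(x) = f(x) - f(x')\) holds to the same tolerance.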

captum package

PyTorch interpretability library:

  • captum.attr: Saliency, IG, DeepLift, SHAP
  • captum.metrics: Robustness
  • captum.concept: TCAV (next week)

Gradients in captum

from captum.attr import Saliency

# Wrap a trained PyTorch model; attribute() returns the input-gradient
# attributions for the requested target class.
slc = Saliency(model)
attr = slc.attribute(input_tensor, target=class_idx)

Exercise: Pseudocode

Goal: Write the loop for Integrated Gradients.

    Initialize total_gradients = 0.
    Define baseline and steps = 50.
    For k in range 1 to steps:
        ### ???

    attribution = (input - baseline) * (total_gradients / steps)

SHAP

SHAP for Images

SHAP treats pixels as the players in a cooperative game.

\[\varphi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!(|N|-|S|-1)!}{|N|!} [v(S \cup \{i\}) - v(S)]\]

Problem: \(256 \times 256\) image has \(2^{65536}\) coalitions.

Approximations:

  • KernelSHAP: weighted regression
  • Superpixels: group pixels to reduce \(|N|\)
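For a game with only a few players, the exact Shapley formula is tractable and can be computed directly. A sketch with a made-up 3-player value function \(v\), chosen so that players 0 and 1 have a synergy:

```python
from itertools import combinations
from math import factorial

N = [0, 1, 2]

def v(S):
    # Hypothetical coalition value: players 0 and 1 are worth 1 each
    # alone and gain a synergy bonus of 2 together; player 2 adds nothing.
    S = set(S)
    val = 0.0
    if 0 in S:
        val += 1.0
    if 1 in S:
        val += 1.0
    if {0, 1} <= S:
        val += 2.0
    return val

def shapley(i):
    """Exact Shapley value of player i via the weighted-sum formula."""
    n = len(N)
    others = [j for j in N if j != i]
    phi = 0.0
    for size in range(n):
        for S in combinations(others, size):
            weight = factorial(len(S)) * factorial(n - len(S) - 1) / factorial(n)
            phi += weight * (v(S + (i,)) - v(S))
    return phi

phis = [shapley(i) for i in N]
```

The result satisfies efficiency (\(\sum_i \varphi_i = v(N)\)) and symmetry (players 0 and 1 receive equal credit); the exponential cost of enumerating coalitions is exactly what makes this infeasible at the pixel level.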

SHAP in captum

Captum provides KernelShap and ShapleyValueSampling.

  • Feature Masking: Users can define a feature_mask to group pixels into semantic regions (e.g., all pixels in a solar panel).
  • Efficiency: Instead of evaluating all \(2^{|N|}\) coalitions, it uses a limited budget of sampled permutations to estimate the values.
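The permutation-sampling idea can be sketched in plain Python. This assumes a hypothetical additive value function \(v\), for which the exact Shapley value of feature \(i\) is simply its own weight, so the estimate can be checked:

```python
import random

N = [0, 1, 2, 3]
weights = [3.0, -1.0, 2.0, 0.5]

def v(S):
    # Additive game: coalition value is the sum of member weights.
    return sum(weights[i] for i in S)

def sampled_shapley(n_perms=500, seed=0):
    """Average each feature's marginal contribution over random orderings."""
    rng = random.Random(seed)
    est = [0.0] * len(N)
    for _ in range(n_perms):
        perm = N[:]
        rng.shuffle(perm)
        prefix = []
        for i in perm:
            # Marginal contribution of i given the features before it.
            est[i] += v(prefix + [i]) - v(prefix)
            prefix.append(i)
    return [e / n_perms for e in est]

phis = sampled_shapley()
```

Each permutation costs \(|N| + 1\) model evaluations, so the budget is controlled by the number of permutations rather than by \(2^{|N|}\).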

Feature Neighborhoods

  • Another idea is to restrict the collection of sets in the summation.

  • This is most natural when there is a notion of distance between features. For example, for a word at the start of a sentence, don’t bother with sets of words near the end.

\(L\)-Shapley (Local Approximation)

Let \(N_k(i)\) = features within distance \(k\) of \(i\).

\[\varphi^{L}(f, i) = \frac{1}{|N_k(i)|} \sum_{\substack{S \subseteq N_k(i) \\ i \in S}} \frac{1}{\binom{|N_k(i)|-1}{|S|-1}} \left[v(S) - v(S \setminus \{i\})\right]\]

Efficient when \(k \ll |N|\).
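A sketch of \(L\)-Shapley on a 1-D sequence of features, assuming a hypothetical additive value function \(v\) (for which the exact Shapley value of feature \(i\) is its own contribution \(c_i\), so the local approximation can be checked):

```python
from itertools import combinations
from math import comb

n, k = 6, 1
features = list(range(n))
c = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]   # per-feature contributions

def v(S):
    # Additive game: coalition value is the sum of member contributions.
    return sum(c[j] for j in S)

def l_shapley(i):
    """Sum only over subsets of the radius-k neighborhood that contain i."""
    Nk = [j for j in features if abs(j - i) <= k]
    others = [j for j in Nk if j != i]
    total = 0.0
    for size in range(len(others) + 1):
        for T in combinations(others, size):
            S = set(T) | {i}
            total += (v(S) - v(S - {i})) / comb(len(Nk) - 1, len(S) - 1)
    return total / len(Nk)

phis = [l_shapley(i) for i in features]
```

With \(k = 1\), each feature's sum ranges over at most \(2^2 = 4\) local subsets instead of \(2^5 = 32\) global ones, which is the source of the speedup when \(k \ll |N|\).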

Exercise: Interpretability Goals vs. Method